{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "def to_problem_answer(row):\n",
    "    \"\"\"Return the columns to set for one row: `problem` (copied from `question`) and `answer`.\"\"\"\n",
    "    return {\"problem\": row[\"question\"], \"answer\": row[\"answer\"]}\n",
    "\n",
    "\n",
    "# Fetch the STILL-3 preview RL dataset by its Hugging Face Hub id.\n",
    "ds = load_dataset(\"RUC-AIBOX/STILL-3-Preview-RL-Data\")\n",
    "# Add a `problem` column mirroring `question` on the train split; the other columns\n",
    "# (e.g. `question`, `messages`) are still present afterwards and are stripped later.\n",
    "still_ds = ds['train'].map(to_problem_answer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the listed JSON files from ../train and merge them into one list\n",
    "import os\n",
    "import json\n",
    "\n",
    "prior_data = []\n",
    "for file in ['aime.json', 'amc.json', 'math.json', 'omni_math.json']:\n",
    "    # Resolve the file name against ../train into an absolute path\n",
    "    file_path = os.path.abspath(os.path.join(\"../train\", file))\n",
    "    # JSON is UTF-8 by specification; do not rely on the platform's default encoding\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "    # Assumes each file holds a JSON list of records with a \"problem\" key — TODO confirm\n",
    "    prior_data.extend(data)\n",
    "\n",
    "len(prior_data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepscaler.utils import RAG\n",
    "\n",
    "# Hand the problem text of every previously loaded record to RAG.\n",
    "# NOTE(review): presumably this builds a similarity index over the texts —\n",
    "# confirm against deepscaler.utils.RAG.\n",
    "rag_searcher = RAG(\n",
    "    docs=[record[\"problem\"] for record in prior_data],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the STILL problems whose best match among the prior problems scores low\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Rows whose top-1 search score exceeds this value are dropped as already covered.\n",
    "# NOTE(review): assumes the score is a similarity where higher means closer — confirm.\n",
    "DUPLICATE_SCORE_THRESHOLD = 0.90\n",
    "\n",
    "filter_problems = []  # rows kept (score at or below the threshold)\n",
    "num_problems = 0  # rows dropped (score above the threshold)\n",
    "\n",
    "# tqdm already reports progress, so no manual counter/print is needed\n",
    "for d in tqdm(still_ds, desc=\"Filtering stil data\", total=len(still_ds)):\n",
    "    # First (k=1) search result for this problem text\n",
    "    search_result = rag_searcher.top_k(d[\"problem\"], k=1)[0]\n",
    "    score = search_result[\"score\"]\n",
    "    if score > DUPLICATE_SCORE_THRESHOLD:\n",
    "        num_problems += 1\n",
    "    else:\n",
    "        filter_problems.append(d)\n",
    "\n",
    "# Save the kept rows as JSON\n",
    "with open(\"still.json\", \"w\") as f:\n",
    "    json.dump(filter_problems, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the `question` and `messages` fields from every kept row before saving.\n",
    "# pop(..., None) is used instead of del so the cell can be re-run safely: with del,\n",
    "# a second run raises KeyError because the keys were already removed in place.\n",
    "for d in filter_problems:\n",
    "    d.pop(\"question\", None)\n",
    "    d.pop('messages', None)\n",
    "\n",
    "# Save final list as json (overwrites still.json)\n",
    "with open(\"still.json\", \"w\") as f:\n",
    "    json.dump(filter_problems, f, indent=2)\n",
    "len(filter_problems)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepscaler",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
